import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeRegressor
dnd_df = pd.read_csv("super_heroes_dnd_v3a.csv")
dnd_df.head()
 | ID | Name | Gender | Race | Height | Publisher | Alignment | Weight | STR | DEX | CON | INT | WIS | CHA | Level | HP |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | A001 | A-Bomb | Male | Human | 203.0 | Marvel Comics | good | 441.0 | 18 | 11 | 17 | 12 | 13 | 11 | 1 | 7 |
1 | A002 | Abe Sapien | Male | Icthyo Sapien | 191.0 | Dark Horse Comics | good | 65.0 | 16 | 17 | 10 | 13 | 15 | 11 | 8 | 72 |
2 | A004 | Abomination | Male | Human / Radiation | 203.0 | Marvel Comics | bad | 441.0 | 13 | 14 | 13 | 10 | 18 | 15 | 15 | 135 |
3 | A009 | Agent 13 | Female | NaN | 173.0 | Marvel Comics | good | 61.0 | 15 | 18 | 16 | 16 | 17 | 10 | 14 | 140 |
4 | A015 | Alex Mercer | Male | Human | NaN | Wildstorm | bad | NaN | 14 | 17 | 13 | 12 | 10 | 11 | 9 | 72 |
dnd_df.dtypes
ID            object
Name          object
Gender        object
Race          object
Height       float64
Publisher     object
Alignment     object
Weight       float64
STR            int64
DEX            int64
CON            int64
INT            int64
WIS            int64
CHA            int64
Level          int64
HP             int64
dtype: object
dnd_df_2 = dnd_df.iloc[:, np.r_[8:14, 15]]
dnd_df_2
# Alternatively, use:
# dnd_df.iloc[:, list(range(8,14)) + [15]]
# Note that the stop of the 8:14 range is exclusive: it selects columns 8-13 (STR through CHA), and 15 is HP
# Or just use:
dnd_df.iloc[:, [8, 9, 10, 11, 12, 13, 15]]
 | STR | DEX | CON | INT | WIS | CHA | HP |
---|---|---|---|---|---|---|---|
0 | 18 | 11 | 17 | 12 | 13 | 11 | 7 |
1 | 16 | 17 | 10 | 13 | 15 | 11 | 72 |
2 | 13 | 14 | 13 | 10 | 18 | 15 | 135 |
3 | 15 | 18 | 16 | 16 | 17 | 10 | 140 |
4 | 14 | 17 | 13 | 12 | 10 | 11 | 72 |
... | ... | ... | ... | ... | ... | ... | ... |
729 | 8 | 14 | 17 | 13 | 14 | 15 | 64 |
730 | 17 | 12 | 11 | 11 | 14 | 10 | 56 |
731 | 18 | 10 | 14 | 17 | 10 | 10 | 49 |
732 | 11 | 11 | 10 | 12 | 15 | 16 | 36 |
733 | 16 | 12 | 18 | 15 | 15 | 16 | 81 |
734 rows × 7 columns
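Equivalently, and more robustly if the column order ever changes, the columns can be selected by name. A minimal sketch, using the column names from the dtypes listing above:
dnd_df_2 = dnd_df[["STR", "DEX", "CON", "INT", "WIS", "CHA", "HP"]]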
import sklearn
from sklearn.model_selection import train_test_split
predictors = ["STR", "DEX", "CON", "INT", "WIS", "CHA"]
outcome = "HP"
X = dnd_df_2[predictors]
y = dnd_df_2[outcome]
train_X, valid_X, train_y, valid_y = train_test_split(X, y, test_size = 0.4, random_state = 666)
train_X.head()
 | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
650 | 17 | 14 | 16 | 16 | 15 | 17 |
479 | 8 | 18 | 16 | 10 | 14 | 17 |
271 | 9 | 12 | 17 | 10 | 15 | 17 |
647 | 9 | 18 | 16 | 10 | 17 | 13 |
307 | 12 | 16 | 14 | 18 | 15 | 13 |
len(train_X)
440
train_y.head()
650    117
479    120
271     72
647    117
307    100
Name: HP, dtype: int64
len(train_y)
440
valid_X.head()
 | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
389 | 10 | 16 | 15 | 13 | 11 | 10 |
131 | 18 | 10 | 12 | 10 | 16 | 18 |
657 | 10 | 11 | 12 | 11 | 18 | 14 |
421 | 16 | 13 | 11 | 16 | 13 | 11 |
160 | 12 | 16 | 17 | 18 | 11 | 15 |
len(valid_X)
294
valid_y.head()
389    45
131    42
657    63
421    64
160    54
Name: HP, dtype: int64
len(valid_y)
294
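A quick sanity check on the split: the two pieces should add back up to the 734 original rows, with roughly 60% in training.
# Verify the 60/40 split (440 + 294 = 734)
assert len(train_X) + len(valid_X) == len(dnd_df_2)
print(len(train_X) / len(dnd_df_2))  # ~0.6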
full_tree = DecisionTreeRegressor(random_state = 666)
full_tree
DecisionTreeRegressor(random_state=666)
full_tree_fit = full_tree.fit(train_X, train_y)
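With the default settings, a DecisionTreeRegressor keeps splitting until each leaf is pure (or contains a single sample), so the full tree fits the training data almost perfectly and is very likely to overfit.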
Plot the tree
from sklearn import tree
Export the top levels for illustration using max_depth. Export the whole tree if max_depth is excluded.
text_representation = tree.export_text(full_tree, max_depth = 5)
print(text_representation)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- feature_3 <= 10.50
|   |   |   |--- feature_5 <= 16.00
|   |   |   |   |--- feature_2 <= 13.50
|   |   |   |   |   |--- value: [18.00]
|   |   |   |   |--- feature_2 > 13.50
|   |   |   |   |   |--- feature_0 <= 11.50
|   |   |   |   |   |   |--- value: [48.00]
|   |   |   |   |   |--- feature_0 > 11.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |--- feature_5 > 16.00
|   |   |   |   |--- value: [9.00]
|   |   |--- feature_3 > 10.50
|   |   |   |--- feature_2 <= 13.50
|   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |--- feature_0 <= 12.00
|   |   |   |   |   |   |--- value: [40.00]
|   |   |   |   |   |--- feature_0 > 12.00
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |--- feature_0 > 17.50
|   |   |   |   |   |--- value: [90.00]
|   |   |   |--- feature_2 > 13.50
|   |   |   |   |--- feature_0 <= 10.50
|   |   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |   |--- value: [56.00]
|   |   |   |   |   |--- feature_3 > 13.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |   |--- feature_0 > 10.50
|   |   |   |   |   |--- feature_2 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_2 > 17.50
|   |   |   |   |   |   |--- value: [120.00]
|   |--- feature_3 > 14.50
|   |   |--- feature_5 <= 10.50
|   |   |   |--- feature_0 <= 17.50
|   |   |   |   |--- feature_4 <= 16.00
|   |   |   |   |   |--- value: [112.00]
|   |   |   |   |--- feature_4 > 16.00
|   |   |   |   |   |--- value: [84.00]
|   |   |   |--- feature_0 > 17.50
|   |   |   |   |--- value: [49.00]
|   |   |--- feature_5 > 10.50
|   |   |   |--- feature_4 <= 11.50
|   |   |   |   |--- feature_3 <= 16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |   |--- feature_4 > 10.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |--- feature_3 > 16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [54.00]
|   |   |   |   |   |--- feature_4 > 10.50
|   |   |   |   |   |   |--- value: [20.00]
|   |   |   |--- feature_4 > 11.50
|   |   |   |   |--- feature_4 <= 17.50
|   |   |   |   |   |--- feature_5 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_5 > 12.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |--- feature_4 > 17.50
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 > 17.50
|   |   |   |   |   |   |--- value: [50.00]
|--- feature_1 > 10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- feature_5 <= 10.50
|   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |--- feature_4 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_4 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |--- feature_2 > 12.50
|   |   |   |   |   |--- feature_2 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |--- feature_2 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_5 > 10.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_3 <= 10.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_3 > 10.50
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |--- feature_5 > 17.50
|   |   |   |   |   |--- feature_2 <= 15.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_2 > 15.50
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |--- feature_2 > 17.50
|   |   |   |--- feature_1 <= 15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_0 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_0 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 > 12.50
|   |   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |--- feature_0 > 17.50
|   |   |   |   |   |   |--- value: [8.00]
|   |   |   |--- feature_1 > 15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_3 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 > 12.50
|   |   |   |   |   |--- feature_1 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_1 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |--- feature_4 > 17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- feature_3 <= 16.50
|   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |   |--- feature_1 > 14.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_3 > 13.50
|   |   |   |   |   |--- feature_1 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_1 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_3 > 16.50
|   |   |   |   |--- feature_2 <= 17.00
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- value: [72.00]
|   |   |   |   |   |--- feature_3 > 17.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_2 > 17.00
|   |   |   |   |   |--- value: [117.00]
|   |   |--- feature_0 > 14.50
|   |   |   |--- feature_0 <= 15.50
|   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |--- feature_5 <= 16.00
|   |   |   |   |   |   |--- value: [9.00]
|   |   |   |   |   |--- feature_5 > 16.00
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |--- feature_1 > 14.50
|   |   |   |   |   |--- value: [28.00]
|   |   |   |--- feature_0 > 15.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- feature_2 > 12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_5 > 17.50
|   |   |   |   |   |--- value: [112.00]
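The feature_0 through feature_5 labels follow the column order of train_X (STR, DEX, CON, INT, WIS, CHA). Passing feature_names to export_text prints the actual names instead, for example:
# Readable variant (feature_names is a standard export_text argument)
print(tree.export_text(full_tree, max_depth = 5, feature_names = list(train_X.columns)))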
Plot the top 5 levels for illustration using max_depth. Plot the whole tree if max_depth is excluded.
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 5)
[plot_tree renders the top five levels of the full tree; the returned list of matplotlib Text objects is omitted here]
Export tree and convert to a picture file.
from sklearn.tree import export_graphviz
dot_data = export_graphviz(full_tree, out_file='full_tree.dot', feature_names = train_X.columns)
Not very useful: the fully grown tree is far too deep to interpret.
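To actually render the exported .dot file as an image, one option is the graphviz Python package. A minimal sketch, assuming both the graphviz package and the Graphviz binaries are installed (the command line dot -Tpng full_tree.dot -o full_tree.png works as well):
import graphviz
with open("full_tree.dot") as f:
    graph = graphviz.Source(f.read())
graph.render("full_tree", format = "png", cleanup = True)  # writes full_tree.png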
small_tree = DecisionTreeRegressor(random_state = 666, max_depth = 3, min_samples_split = 25)
small_tree
DecisionTreeRegressor(max_depth=3, min_samples_split=25, random_state=666)
small_tree_fit = small_tree.fit(train_X, train_y)
Plot the tree
# For illustration:
# from sklearn import tree
Export the whole tree; max_depth is not needed here since the pruned tree is only a few levels deep.
text_representation_2 = tree.export_text(small_tree)
print(text_representation_2)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- value: [64.48]
|   |--- feature_3 > 14.50
|   |   |--- value: [40.09]
|--- feature_1 > 10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- value: [69.37]
|   |   |--- feature_2 > 17.50
|   |   |   |--- value: [56.81]
|   |--- feature_4 > 17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- value: [61.81]
|   |   |--- feature_0 > 14.50
|   |   |   |--- value: [45.67]
Plot the whole tree; max_depth can be excluded since the pruned tree is shallow.
tree.plot_tree(small_tree, feature_names = train_X.columns)
[plot_tree renders the pruned tree; the returned list of matplotlib Text objects is omitted here]
Export tree and convert to a picture file.
# For illustration
# from sklearn.tree import export_graphviz
dot_data_2 = export_graphviz(small_tree, out_file='small_tree.dot', feature_names = train_X.columns)
Much better: this tree is small enough to read.
On the training set
train_y_pred = small_tree.predict(train_X)
train_y_pred
array([69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       61.80769231, 40.09090909, 64.47619048, ...])
(440 training-set predictions, each one of the six leaf means; full output omitted)
Get the RMSE for the training set
from sklearn.metrics import mean_squared_error
mse_small_tree_train = mean_squared_error(train_y, train_y_pred)
mse_small_tree_train
1320.9654960245605
import math
rmse_small_tree_train = math.sqrt(mse_small_tree_train)
rmse_small_tree_train
36.34508902210256
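Alternatively, scikit-learn can return the RMSE directly, skipping the math.sqrt step:
# RMSE in one call (scikit-learn 0.22 through 1.5; newer versions provide sklearn.metrics.root_mean_squared_error instead)
rmse_small_tree_train = mean_squared_error(train_y, train_y_pred, squared = False)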
If using the dmba package, install it first:
pip install dmba
import dmba
from dmba import regressionSummary
regressionSummary(train_y, train_y_pred)
Regression statistics

                      Mean Error (ME) : 0.0000
       Root Mean Squared Error (RMSE) : 36.3451
            Mean Absolute Error (MAE) : 30.4192
          Mean Percentage Error (MPE) : -75.0378
Mean Absolute Percentage Error (MAPE) : 103.4166
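The mean error on the training set is zero by construction: a regression tree predicts the mean of the training observations in each leaf, so the training residuals sum to zero.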
On the validation set
valid_y_pred = small_tree.predict(valid_X)
valid_y_pred
array([69.37151703, 64.47619048, 61.80769231, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, ...])
(294 validation-set predictions; full output omitted)
Get the RMSE for the validation set
mse_small_tree_valid = mean_squared_error(valid_y, valid_y_pred)
mse_small_tree_valid
1353.8427521618253
import math
rmse_small_tree_valid = math.sqrt(mse_small_tree_valid)
rmse_small_tree_valid
36.79460221502368
# dmba was installed and imported above:
# pip install dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(valid_y, valid_y_pred)
Regression statistics

                      Mean Error (ME) : 4.2030
       Root Mean Squared Error (RMSE) : 36.7946
            Mean Absolute Error (MAE) : 30.2375
          Mean Percentage Error (MPE) : -48.2303
Mean Absolute Percentage Error (MAPE) : 80.0992
from sklearn.model_selection import GridSearchCV
param_grid = {"max_depth": [2, 3, 4, 5],
              "min_samples_split": [10, 20, 30],
              "min_impurity_decrease": [0, 0.001, 0.002]}
grid_search = GridSearchCV(DecisionTreeRegressor(random_state = 666), param_grid, cv = 10)
grid_search.fit(train_X, train_y)
GridSearchCV(cv=10, estimator=DecisionTreeRegressor(random_state=666),
             param_grid={'max_depth': [2, 3, 4, 5],
                         'min_impurity_decrease': [0, 0.001, 0.002],
                         'min_samples_split': [10, 20, 30]})
print("Initial parameters:", grid_search.best_params_)
Initial parameters: {'max_depth': 2, 'min_impurity_decrease': 0, 'min_samples_split': 10}
grid_search.best_score_
-0.0933578084568057
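Note that GridSearchCV scores a regressor with R² by default, so best_score_ is the mean cross-validated R². A negative value means even the best tree in the grid predicted worse than simply using the mean HP, so the ability scores carry little signal for HP here.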
grid_search.best_params_
{'max_depth': 2, 'min_impurity_decrease': 0, 'min_samples_split': 10}
best_tree = grid_search.best_estimator_
best_tree
DecisionTreeRegressor(max_depth=2, min_impurity_decrease=0, min_samples_split=10, random_state=666)
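To see how the rest of the grid fared, the full cross-validation results can be inspected. A sketch using the standard cv_results_ attribute:
cv_results = pd.DataFrame(grid_search.cv_results_)
# Mean cross-validated R^2 per parameter combination, best first
cv_results[["param_max_depth", "param_min_samples_split",
            "param_min_impurity_decrease", "mean_test_score"]].sort_values(
    "mean_test_score", ascending = False).head()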
dot_data_3 = export_graphviz(best_tree, out_file='best_tree.dot', feature_names = train_X.columns)
On the training set
train_y_best_pred = best_tree.predict(train_X)
train_y_best_pred
array([68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       56.71052632, 40.09090909, 64.47619048, ...])
(440 training-set predictions, each one of the four leaf means; full output omitted)
Get the RMSE for the training set
mse_best_tree_train = mean_squared_error(train_y, train_y_best_pred)
mse_best_tree_train
1337.4509437318577
import math
rmse_best_tree_train = math.sqrt(mse_best_tree_train)
rmse_best_tree_train
36.57117640617892
# If using the dmba package, install it first:
# pip install dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(train_y, train_y_best_pred)
Regression statistics

                      Mean Error (ME) : -0.0000
       Root Mean Squared Error (RMSE) : 36.5712
            Mean Absolute Error (MAE) : 30.6315
          Mean Percentage Error (MPE) : -76.2539
Mean Absolute Percentage Error (MAPE) : 104.6671
On the validation set
valid_y_best_pred = best_tree.predict(valid_X)
valid_y_best_pred
array([68.11142061, 64.47619048, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 56.71052632, ...])
(294 validation-set predictions; full output omitted)
Get the RMSE for the validation set
mse_best_tree_valid = mean_squared_error(valid_y, valid_y_best_pred)
mse_best_tree_valid
1320.469997098233
import math
rmse_best_tree_valid = math.sqrt(mse_best_tree_valid)
rmse_best_tree_valid
36.33827179570092
# If using the dmba package, install it first:
# pip install dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(valid_y, valid_y_best_pred)
Regression statistics

                      Mean Error (ME) : 3.8074
       Root Mean Squared Error (RMSE) : 36.3383
            Mean Absolute Error (MAE) : 29.9632
          Mean Percentage Error (MPE) : -49.3930
Mean Absolute Percentage Error (MAPE) : 80.5791
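On the validation set, the tuned tree edges out the hand-picked small tree (RMSE 36.34 vs 36.79), though with MAPE around 80% neither model predicts HP well.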
new_dnd_df = pd.read_csv("new_records_dnd.csv")
new_dnd_df
 | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 |
1 | 17 | 9 | 17 | 18 | 11 | 7 |
Using the small tree
new_records_dnd_small_pred = small_tree.predict(new_dnd_df)
new_records_dnd_small_pred
array([69.37151703, 40.09090909])
import pandas as pd
dnd_small_tree_prediction = pd.DataFrame(new_records_dnd_small_pred,
                                         columns = ["Prediction"])
dnd_small_tree_prediction
 | Prediction |
---|---|
0 | 69.371517 |
1 | 40.090909 |
Using the best tree
new_records_dnd_best_pred = best_tree.predict(new_dnd_df)
new_records_dnd_best_pred
array([68.11142061, 40.09090909])
dnd_best_tree_prediction = pd.DataFrame(new_records_dnd_best_pred,
                                        columns = ["Prediction"])
dnd_best_tree_prediction
 | Prediction |
---|---|
0 | 68.111421 |
1 | 40.090909 |
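For a side-by-side view, the two sets of predictions can be combined into one frame. A minimal sketch:
# Compare the two models on the new records
comparison = pd.DataFrame({"Small tree": new_records_dnd_small_pred,
                           "Best tree": new_records_dnd_best_pred})
comparison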